'''
和basic区别是,MOE选择topK个专家，然后对这topK个专家的输出进行加权求和，并且把输入样本变成了大模型中真实的输入Shape，（batch, seq_len, hidden_dim）
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicExpert(nn.Module):
    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.fc = nn.Linear(feature_in, feature_out)

    def forward(self, x):
        return self.fc(x)

# 上面的BasicExpert可替换为带有 swishGLU 的 FFN 代码实现，如下所示：
# class FFNExpert(nn.Module):
#     def __init__(self, hidden_dim, dropout):   # LLM 进化之路， FFN 激活函数从 GELU -> SwishGLU
#         super().__init__()
#
#         # 有一个 magic number 叫做 8/3
#         hidden_dim = hidden_dim
#         # 这里可以自己去优化成 multiple_of 的倍数
#         mid_dim = hidden_dim * 8 // 3
#
#         # 一共有三个可训练的参数矩阵, w1, w2, w3
#         self.up = nn.Linear(hidden_dim, mid_dim, bias=False)
#         self.down = nn.Linear(mid_dim, hidden_dim, bias=False)
#         self.gate = nn.Linear(hidden_dim, mid_dim, bias=False)
#
#         self.dropout = nn.Dropout(dropout)
#
#     def forward(self, x):
#         '''
#         swish 是一个非线性函数（激活函数都是如此），具体公式为：
#         swish = x * θ * （β * x）
#         其中 β 是一个超参数，当 β = 1 时，Swish 就变成了 SiLU (Sigmoid Linear Unit)，大多数框架的默认实现（如 PyTorch、TensorFlow 的 nn.SiLU()）使用的是
#         的固定版本。
#         '''
#         out = self.dropout(
#             self.down(
#                 # up 之后的 Shape 是(b, s, mid_dim)
#                 # gate 和 up 之后的Shape都是 (b, s, mid_dim)
#                 # 两者是 element-wise 相乘
#                 self.gate(x) * F.silu(self.up(x))
#             )
#         )
#         return out

# 主要参考自 mistral MOE 的实现
class MOERouter(nn.Module):
    def __init__(self, hidden_dim, expert_number, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, expert_number)

        # 但是后面只会选top_k个专家
        self.expert_number = expert_number
        self.top_k = top_k

    def forward(self, hidden_states):
        # 计算路由logits
        router_logits = self.gate(hidden_states)  # shape is (b * s, expert_number)

        # 计算专家经过softmax之后的概率
        # 最后一维是8
        routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # 计算topk的专家的输出
        # topK是可以反向传播的
        # 最后一维是2
        router_weights, selected_experts = torch.topk(
            routing_probs, self.top_k, dim=-1
        )  # shape都是 (b * s, top_k)

        # 专家权重归一化
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(hidden_states.dtype)

        # 生成专家掩码
        expert_mask = F.one_hot(
            selected_experts,
            num_classes=self.expert_number
        )  # shape是 (b * s, top_k, expert_number)
        expert_mask = expert_mask.permute(2, 1, 0)  # (expert_number, top_k, b * s)

        return router_logits, router_weights, selected_experts, expert_mask


class MOEConfig:
    def __init__(
            self,
            hidden_dim,
            expert_number,
            top_k,
            shared_experts_number=2,
    ):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number


class SparseMOE(nn.Module):
    # 稀疏 MOE 模型，这里每一个 token 都会过 topk 个专家，得到对应token 的 hidden_embeddings
    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim

        self.expert_number = config.expert_number
        self.top_k = config.top_k

        self.experts = nn.ModuleList(
            [
                BasicExpert(self.hidden_dim, self.hidden_dim) for _ in range(self.expert_number)
            ]
        )

        self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)

    def forward(self, x):
        # x shape is (b, s, hidden_dim)
        batch_size, seq_len, hidden_dim = x.size()

        # 合并前两个维度，因为不是 Sample 维度了，而是 token 维度
        hidden_states = x.view(-1, hidden_dim)  # shape is(b * s, hidden_dim)

        router_logits, router_weights, selected_experts_indices, expert_mask = self.router(hidden_states)
        # 其中 router_logits shape是(b * s, expert_number)
        # 其中 router_weights shape是(b * s, top_k)，top_k的值代表专家的权重
        # 其中 selected_experts_indices shape 是 (b * s, top_k)
        # 其中 expert_mask shape 是 (expert_number, top_k, b * s)

        # 最终输出的shape肯定是(b * s, hidden_dim)，这里先做一个初始化
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device
        )

        # 遍历每一个专家
        # 把选中的这个专家的token的hidden_states加到final_hidden_states中
        # expert-0可能有100个token选中了
        # token的总数是b * s
        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]
            # 其中 expert_mask shape 是 (expert_number, top_k, b * s)
            # expert_mask[expert_idx] shape 是 (top_k, b * s)
            idx, top_x = torch.where(expert_mask[expert_idx])
            # idx 和 top_x 都是一维 tensor
            # idx 的值表示当前专家是排名第几位的专家
            # top_x 的值是 token 在 batch*seq_len 中的位置索引
            # 例如对于 batch_size=2, seq_len=4 的输入:
            # top_x 的值范围是 0-7, 表示在展平后的 8 个 token 中的位置

            # idx 是用来选weight的
            # top_x 用来选token的hidden_states

            # hidden_states 的 shape 是 (b * s, hidden_dim)
            # 需要取到 top_x 对应的 hidden_states
            current_state = hidden_states.unsqueeze(
                0
            )[:, top_x, :].reshape(-1, hidden_dim)  # （selected_token_number, hidden_dim）

            # router_weight 的 shape 是 (b * s, top_k)
            # router_weights[top_x, idx]表示选中的这top_x个token的索引为idx的专家的权重
            current_hidden_states = expert_layer(
                current_state
            ) * router_weights[top_x, idx].unsqueeze(-1)  # （selected_token_number, 1） 这里有广播

            # 把当前专家的输出加到 final_hidden_states 中
            # 方式1 的写法性能更好，并且方式1容易出现
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
            # 方式2
            # final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype)
            # 方式1 的写法性能更差，并且方式1容易出现错误，+= 操作在处理重复索引时需要多次读写内存，可能会导致竞争条件

        # 把 final_hidden_states 还原到原来的 shape
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

        return final_hidden_states, router_logits  # shape 是 (b * s, expert_number)


# def test_token_level_moe():
#     x = torch.rand(2, 4, 16)
#     config = MOEConfig(16, 2, 2)
#     token_level_moe = SparseMOE(config)
#     out = token_level_moe(x)
#     print(out[0].shape, out[1].shape)
#
#
# test_token_level_moe()
'''
以上所有内容替换的是传统gpt2架构中的FeedForward层
'''